"""Modular cabinets."""
from abc import ABC
from pathlib import Path
from typing import Optional

import numpy as np
from dm_control import mjcf
from mojo import Mojo
from mojo.elements import Body, Geom, MujocoElement

from bigym.const import ASSETS_PATH
from bigym.envs.props.holders import CutleryTray
from bigym.envs.props.prop import Prop
from bigym.utils.physics_utils import set_joint_position, get_joint_position


class ModularProp(Prop, ABC):
    """Base modular prop.

    Stores the joints of the prop body and exposes them as a flat state
    vector through ``get_state`` / ``set_state``. Subclasses use
    ``_configure_body`` to strip optional bodies from the loaded model.
    """

    def __init__(self, mojo: Mojo, model_path: Path, cache_sites: bool):
        """Init.

        Args:
            mojo: Mojo instance forwarded to the base ``Prop``.
            model_path: Path of the model file forwarded to the base ``Prop``.
            cache_sites: Forwarded to the base ``Prop`` as ``cache_sites``.
        """
        super().__init__(mojo, model_path, cache_sites=cache_sites)
        # Joints are read once, after the base class has finished initializing.
        self._joints = self.body.joints

    def set_state(self, state: np.ndarray):
        """Set normalized state of joints.

        Values are paired with the stored joints by position. ``zip`` stops
        at the shorter sequence, so surplus values or joints are left as is.

        Args:
            state: One value per joint, in the order of the stored joints.
        """
        for value, joint in zip(state, self._joints):
            set_joint_position(joint, value, True)

    def get_state(self) -> np.ndarray:
        """Get normalized state of joints.

        Returns:
            Array with one entry per stored joint, in joint order.
        """
        return np.array([get_joint_position(joint, True) for joint in self._joints])

    @staticmethod
    def _configure_body(model: mjcf.RootElement, name: str, enable: bool) -> None:
        """Remove the body called ``name`` from ``model`` unless it is enabled.

        Args:
            model: Root element of the loaded model.
            name: Name of the body to keep or remove.
            enable: When true the model is left untouched.
        """
        if enable:
            return
        body = model.find("body", name)
        # A body that is not in the model is already "disabled"; an enabled
        # but missing body is tolerated as well, so stay consistent here
        # instead of failing on ``None.remove()``.
        if body is not None:
            body.remove()


# Model file loaded by BaseCabinet, located under the assets directory.
BASE_CABINET_PATH = ASSETS_PATH / "props/kitchen/base_cabinet_600.xml"


class BaseCabinet(ModularProp):
    """Modular base cabinet.

    The model at ``BASE_CABINET_PATH`` is loaded and then pruned: every
    optional body (walls, drawers, hob, panel, doors, shelf) whose flag is
    false is removed from the model in ``_on_loaded``.
    """

    _BIG_DRAWER_NAMES = ["drawer_big_1", "drawer_big_2"]
    _SMALL_DRAWER_NAMES = [
        "drawer_small_1",
        "drawer_small_2",
        "drawer_small_3",
        "drawer_small_4",
    ]
    _DOOR_LEFT = "door_left"
    _DOOR_RIGHT = "door_right"
    _PANEL = "panel"
    _SHELF = "shelf"
    _SHELF_BOTTOM = "shelf_bottom"
    _HOB = "hob"
    _WALLS = "walls"
    _COUNTER = "counter"

    def __init__(
        self,
        mojo: Mojo,
        walls_enable: bool = True,
        big_drawers_enable: Optional[list[bool]] = None,
        small_drawers_enable: Optional[list[bool]] = None,
        hob_enable: bool = False,
        panel_enable: bool = False,
        door_left_enable: bool = False,
        door_right_enable: bool = False,
        shelf_enable: bool = False,
    ):
        """Init.

        Args:
            mojo: Mojo instance forwarded to the base class.
            walls_enable: Keep the ``walls`` body.
            big_drawers_enable: One flag per entry of ``_BIG_DRAWER_NAMES``;
                ``None`` disables all big drawers.
            small_drawers_enable: One flag per entry of
                ``_SMALL_DRAWER_NAMES``; ``None`` disables all small drawers.
            hob_enable: Keep the ``hob`` body.
            panel_enable: Keep the ``panel`` body. Ignored when any drawer is
                enabled.
            door_left_enable: Keep the left door. Ignored when any drawer or
                the panel is enabled.
            door_right_enable: Keep the right door. Ignored when any drawer
                or the panel is enabled.
            shelf_enable: Keep the ``shelf`` body. Ignored when any drawer is
                enabled.

        Raises:
            ValueError: If a drawer flag list does not have one entry per
                drawer name.
        """
        if big_drawers_enable is None:
            big_drawers_enable = [False] * len(self._BIG_DRAWER_NAMES)
        if small_drawers_enable is None:
            small_drawers_enable = [False] * len(self._SMALL_DRAWER_NAMES)

        # Explicit checks instead of ``assert`` so they survive ``python -O``;
        # a wrong length would otherwise be silently truncated by ``zip``.
        if len(big_drawers_enable) != len(self._BIG_DRAWER_NAMES):
            raise ValueError(
                f"big_drawers_enable must have {len(self._BIG_DRAWER_NAMES)} "
                f"entries, got {len(big_drawers_enable)}"
            )
        if len(small_drawers_enable) != len(self._SMALL_DRAWER_NAMES):
            raise ValueError(
                f"small_drawers_enable must have {len(self._SMALL_DRAWER_NAMES)} "
                f"entries, got {len(small_drawers_enable)}"
            )

        # Drawers take precedence over shelf, doors and panel; the panel in
        # turn takes precedence over the doors.
        has_drawers = any(big_drawers_enable) or any(small_drawers_enable)
        if has_drawers:
            shelf_enable = False
            door_right_enable = False
            door_left_enable = False
            panel_enable = False
        if panel_enable:
            door_left_enable = False
            door_right_enable = False

        self._walls_enable = walls_enable
        # Copies, so later changes to the caller's lists cannot leak in.
        self._big_drawers_enable = list(big_drawers_enable)
        self._small_drawers_enable = list(small_drawers_enable)
        self._hob_enable = hob_enable
        self._panel_enable = panel_enable
        self._door_left_enable = door_left_enable
        self._door_right_enable = door_right_enable
        self._shelf_enable = shelf_enable

        # Placeholders for the geoms assigned in ``_on_loaded``. They are
        # declared before the base initializer so they never overwrite values
        # set by that hook.
        self.shelf: Optional[Geom] = None
        self.shelf_bottom: Optional[Geom] = None
        self.counter: Optional[Geom] = None
        self.hob: Optional[Geom] = None

        # Third argument is ``cache_sites`` of ``ModularProp.__init__``.
        super().__init__(mojo, BASE_CABINET_PATH, has_drawers or hob_enable)

    def _on_loaded(self, model: mjcf.RootElement):
        """Capture surface geoms and strip the disabled bodies from ``model``."""
        cabinet = MujocoElement(self._mojo, model)
        # The last geom of each body is captured before any body is removed,
        # so ``shelf`` and ``hob`` are set even when that body is disabled.
        # NOTE(review): confirm callers do not use them in that case.
        self.shelf = Body.get(self._mojo, self._SHELF, cabinet).geoms[-1]
        self.shelf_bottom = Body.get(self._mojo, self._SHELF_BOTTOM, cabinet).geoms[-1]
        self.counter = Body.get(self._mojo, self._COUNTER, cabinet).geoms[-1]
        self.hob = Body.get(self._mojo, self._HOB, cabinet).geoms[-1]

        # Flag lists were validated to match the name lists in ``__init__``,
        # so concatenating both sides keeps names and flags aligned.
        drawer_flags = zip(
            self._BIG_DRAWER_NAMES + self._SMALL_DRAWER_NAMES,
            self._big_drawers_enable + self._small_drawers_enable,
        )
        for drawer_name, enable in drawer_flags:
            self._configure_body(model, drawer_name, enable)

        fixed_bodies = (
            (self._WALLS, self._walls_enable),
            (self._HOB, self._hob_enable),
            (self._PANEL, self._panel_enable),
            (self._DOOR_LEFT, self._door_left_enable),
            (self._DOOR_RIGHT, self._door_right_enable),
            (self._SHELF, self._shelf_enable),
        )
        for body_name, enable in fixed_bodies:
            self._configure_body(model, body_name, enable)


class BaseCabinetForCutlery(BaseCabinet):
    """Base cabinet preset with cutlery tray in the top drawer.

    All four small drawers are enabled and a ``CutleryTray`` is created on
    the site named ``_TRAY_SITE_NAME``.
    """

    _TRAY_SITE_NAME = "drawer_small_4"

    def __init__(self, mojo: Mojo):
        """Init."""
        super().__init__(mojo, small_drawers_enable=[True] * 4)

        # First site with the tray name; stays None when there is no match.
        tray_site = None
        for candidate in self.sites:
            if candidate.mjcf.name == self._TRAY_SITE_NAME:
                tray_site = candidate
                break

        self.tray = CutleryTray(self._mojo, tray_site)


# Model file loaded by WallCabinet, located under the assets directory.
WALL_CABINET_PATH = ASSETS_PATH / "props/kitchen/wall_cabinet_600.xml"


class WallCabinet(ModularProp):
    """Modular wall cabinet.

    Solid doors, glass doors and the vent are optional bodies; each group is
    removed from the loaded model when its flag is false.
    """

    _DOORS = ["door_right", "door_left"]
    _GLASS_DOORS = ["door_right_glass", "door_left_glass"]
    _VENT = "vent"
    _SHELF = "shelf"
    _SHELF_BOTTOM = "shelf_bottom"

    def __init__(
        self,
        mojo: Mojo,
        doors_enable: bool = False,
        glass_doors_enable: bool = False,
        vent_enable: bool = False,
    ):
        """Init."""
        # Geom placeholders, assigned in ``_on_loaded``.
        self.shelf: Optional[Geom] = None
        self.shelf_bottom: Optional[Geom] = None

        self._vent_enable = vent_enable
        self._glass_doors_enable = glass_doors_enable
        self._doors_enable = doors_enable

        super().__init__(mojo, WALL_CABINET_PATH, False)

    def _on_loaded(self, model: mjcf.RootElement):
        """Capture the shelf geoms and strip the disabled bodies."""
        root = MujocoElement(self._mojo, model)
        shelf_body = Body.get(self._mojo, self._SHELF, root)
        self.shelf = shelf_body.geoms[-1]
        bottom_body = Body.get(self._mojo, self._SHELF_BOTTOM, root)
        self.shelf_bottom = bottom_body.geoms[-1]

        # (body name, keep?) pairs in processing order: doors, glass doors, vent.
        toggles = [
            *((name, self._doors_enable) for name in self._DOORS),
            *((name, self._glass_doors_enable) for name in self._GLASS_DOORS),
            (self._VENT, self._vent_enable),
        ]
        for body_name, keep in toggles:
            self._configure_body(model, body_name, keep)


# Model file loaded by OpenShelf, located under the assets directory.
OPEN_SHELF_PATH = ASSETS_PATH / "props/kitchen/open_shelf_600.xml"


class OpenShelf(ModularProp):
    """Modular open shelf.

    Exposes the last geom of the ``shelf`` body as ``self.shelf``.
    """

    _SHELF = "shelf"

    def __init__(
        self,
        mojo: Mojo,
    ):
        """Init.

        Args:
            mojo: Mojo instance forwarded to the base class.
        """
        # Declared before the base initializer so the attribute always exists
        # with a known type and is never reset after ``_on_loaded`` fills it.
        self.shelf: Optional[Geom] = None
        super().__init__(mojo, OPEN_SHELF_PATH, False)

    def _on_loaded(self, model: mjcf.RootElement):
        """Capture the shelf geom from the loaded model."""
        shelf = MujocoElement(self._mojo, model)
        self.shelf = Body.get(self._mojo, self._SHELF, shelf).geoms[-1]
